package attributes

import (
	"google.golang.org/grpc/attributes"

	"github.com/milvus-io/milvus/internal/util/sessionutil"
	"github.com/milvus-io/milvus/pkg/v2/streaming/util/types"
)

// attributesKeyType is the unexported key type this package uses when storing
// values in an Attributes set. Since the type is private to this package, a key
// declared here can never compare equal to a key declared by another package.
type attributesKeyType int

const (
	// serverIDKey indexes the server ID; GetServerID reads it back as an int64.
	serverIDKey attributesKeyType = iota
	// channelAssignmentInfoKey indexes the *types.StreamingNodeAssignment.
	channelAssignmentInfoKey
	// sessionKey indexes the *sessionutil.SessionRaw.
	sessionKey
)

// Attributes is an alias of the gRPC attributes.Attributes type; the helpers in
// this package read and write their values on it.
type Attributes = attributes.Attributes

// GetServerID looks up the server ID stored in attr.
// It returns nil when no server ID has been recorded; otherwise it returns a
// pointer to a copy of the stored int64 value.
func GetServerID(attr *Attributes) *int64 {
	if raw := attr.Value(serverIDKey); raw != nil {
		// The value under serverIDKey is asserted to be an int64; any other
		// dynamic type makes this assertion panic.
		id := raw.(int64)
		return &id
	}
	return nil
}

// WithServerID derives an Attributes from attr in which serverID is stored
// under the server ID key, and returns the derived value.
func WithServerID(attr *Attributes, serverID int64) *Attributes {
	updated := attr.WithValue(serverIDKey, serverID)
	return updated
}

// WithChannelAssignmentInfo derives an Attributes from attr that carries the
// given assignment and, under the server ID key, assignment.NodeInfo.ServerID.
// assignment must be non-nil because its NodeInfo is read unconditionally.
func WithChannelAssignmentInfo(attr *Attributes, assignment *types.StreamingNodeAssignment) *attributes.Attributes {
	// First attach the assignment itself.
	withAssignment := attr.WithValue(channelAssignmentInfoKey, assignment)
	// Then mirror the node's server ID so GetServerID works on the result.
	return withAssignment.WithValue(serverIDKey, assignment.NodeInfo.ServerID)
}

// GetChannelAssignmentInfoFromAttributes returns the channel assignment info
// fetched from streamingcoord that is stored in attrs, or nil if none is stored.
// It is generated by the channel assignment discoverer and sent to the channel
// assignment balancer.
func GetChannelAssignmentInfoFromAttributes(attrs *Attributes) *types.StreamingNodeAssignment {
	if raw := attrs.Value(channelAssignmentInfoKey); raw != nil {
		// A stored value of any other dynamic type makes this assertion panic.
		return raw.(*types.StreamingNodeAssignment)
	}
	return nil
}

// WithSession derives an Attributes from attr that carries the given session
// and, under the server ID key, val.ServerID.
// val must be non-nil because its ServerID is read unconditionally.
func WithSession(attr *Attributes, val *sessionutil.SessionRaw) *attributes.Attributes {
	// First attach the session itself.
	withSession := attr.WithValue(sessionKey, val)
	// Then mirror the session's server ID so GetServerID works on the result.
	return withSession.WithValue(serverIDKey, val.ServerID)
}

// GetSessionFromAttributes returns the session stored in attrs, or nil if no
// session has been stored.
func GetSessionFromAttributes(attrs *Attributes) *sessionutil.SessionRaw {
	if raw := attrs.Value(sessionKey); raw != nil {
		// A stored value of any other dynamic type makes this assertion panic.
		return raw.(*sessionutil.SessionRaw)
	}
	return nil
}
